#include <dirent.h>
#include <pthread.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include <algorithm>
#include <atomic>
#include <cmath>
#include <map>
#include <queue>
#include <unordered_map>
#include <utility>
#include <vector>

#include "acllite_dvpp_lite/ImageProc.h"
#include "acllite_om_execute/ModelProc.h"
#include "acllite_common/Queue.h"
#include "label.h"

using namespace std;
using namespace acllite;

// ACL context obtained from aclResource in main; each worker thread binds it
// with SetCurContext() before touching device resources.
aclrtContext context{nullptr};
// Target resolution each decoded frame is resized to before inference
// (presumably the input shape of resnet50.om — confirm if the model changes).
uint32_t modelWidth = 224;
uint32_t modelHeight = 224;
// Set by the PostProcess thread once the pipeline has drained and read from
// another thread, so it must be atomic: a plain bool here is a data race (UB).
std::atomic<bool> exitFlag{false};
// Work item handed from GetInput to ModelExecute through msgDataQueue.
// A message with videoEnd set carries no image and marks the end of the stream.
struct MsgData {
    std::shared_ptr<uint8_t> data{};  // resized image buffer (shared with the ImageData it came from)
    uint32_t size{0};                 // buffer size as reported by ImageData::size
    bool videoEnd{false};             // end-of-stream marker
};

struct MsgOut {
    bool videoEnd = false;
    vector<InferenceOutput> inferOutputs;
};

// Constructor argument shared by both inter-stage queues (presumably their
// capacity: producers keep retrying Push() while it returns false).
static const uint32_t kPipelineQueueDepth = 32;
// GetInput -> ModelExecute
Queue<MsgData> msgDataQueue(kPipelineQueueDepth);
// ModelExecute -> PostProcess
Queue<MsgOut> msgOutQueue(kPipelineQueueDepth);

void GetResult(std::vector<InferenceOutput>& inferOutputs)
{
    uint32_t dataSize = inferOutputs[0].size;
    // get result from output data set
    float* outData = static_cast<float*>(inferOutputs[0].data.get());
    if (outData == nullptr) {
        LOG_PRINT("get result from output data set failed.");
        return ;
    }
    map<float, unsigned int, greater<float> > resultMap;
    for (uint32_t j = 0; j < dataSize / sizeof(float); ++j) {
        resultMap[*outData] = j;
        outData++;
    }

    uint32_t topConfidenceLevels = 5;
    double totalValue = 0.0;
    for (auto it = resultMap.begin(); it != resultMap.end(); ++it) {
        totalValue += exp(it->first);
    }

    int cnt = 0;
    for (auto it = resultMap.begin(); it != resultMap.end(); ++it) {
        // print top 5
        if (++cnt > topConfidenceLevels) {
            break;
        }
        LOG_PRINT("[INFO] top %d: index[%d] value[%lf] class[%s]", cnt, it->second,
                         exp(it->first) / totalValue, label[it->second].c_str());
    }
    outData = nullptr;
    return;
}

void* GetInput(void* arg) {
    bool ret = SetCurContext(context);
    CHECK_RET(ret, LOG_PRINT("[ERROR] set cur context for pthread  %ld failed.", pthread_self()); return NULL);
    int32_t deviceId = *(int32_t *)arg;
    ImageProc imageProcess;
    ImageData frame;
    ImageSize modelSize(modelWidth, modelHeight);
    LOG_PRINT("[INFO] start to decode...");
    uint32_t cnt = 1000;
    while(cnt) {
        frame = imageProcess.Read("../data/dog1_1024_683.jpg");
        CHECK_RET(frame.size, LOG_PRINT("[ERROR] Read image failed."); break);
        ImageData dst;
        imageProcess.Resize(frame, dst, modelSize);
        MsgData msgData;
        msgData.data = dst.data;
        msgData.size = dst.size;
        msgData.videoEnd = false;
            while (1) {
                if (msgDataQueue.Push(msgData)) {
                    break;
                }
                usleep(100);
            }
        cnt--;
    }
    MsgData msgData;
    msgData.videoEnd = true;
    while (1) {
        if (msgDataQueue.Push(msgData)) {
            break;
        }
        usleep(100);
    }
    LOG_PRINT("[INFO] preprocess add end msgData. tid : %ld", pthread_self());
    return NULL;
}
 
void* ModelExecute(void* arg) {
    bool ret = SetCurContext(context);
    CHECK_RET(ret, LOG_PRINT("[ERROR] set cur context for pthread  %ld failed.", pthread_self()); return NULL);
    ModelProc modelProcess;
    string modelPath = "../model/resnet50.om";
    ret = modelProcess.Load(modelPath);
    CHECK_RET(ret, LOG_PRINT("[ERROR] load model %s failed.", modelPath.c_str()); return NULL);
    while(1) {
        if(!msgDataQueue.Empty()) {
            MsgData msgData = msgDataQueue.Pop();
            if (msgData.videoEnd) {
                break;
            }
            else {
                ret = modelProcess.CreateInput(static_cast<void*>(msgData.data.get()), msgData.size);
                CHECK_RET(ret, LOG_PRINT("[ERROR] Create model input failed."); break);
                MsgOut msgOut;
                msgOut.videoEnd = msgData.videoEnd;
                modelProcess.Execute(msgOut.inferOutputs);
                CHECK_RET(ret, LOG_PRINT("[ERROR] model execute failed."); break);
                while (1) {
                    if (msgOutQueue.Push(msgOut)) {
                        break;
                    }
                    usleep(100);
                }
            }
        }
    }
    modelProcess.DestroyResource();
    MsgOut msgOut;
    msgOut.videoEnd = true;
    while (1) {
        if (msgOutQueue.Push(msgOut)) {
            break;
        }
        usleep(100);
    }
    LOG_PRINT("[INFO] infer msg end. tid : %ld", pthread_self());
    return  NULL;
}

/**
 * Post-process thread: pops inference results from msgOutQueue and logs the
 * top-5 classes of each via GetResult. When it pops the end-of-stream marker,
 * it logs completion and sets exitFlag.
 *
 * @param arg unused.
 * @return always NULL.
 */
void* PostProcess(void* arg) {
    (void)arg;
    while (true) {
        // Sleep only while there is nothing to do. Sleeping after each result
        // throttled throughput, and spinning on an empty queue wasted a core.
        if (msgOutQueue.Empty()) {
            usleep(100);
            continue;
        }
        MsgOut msgOut = msgOutQueue.Pop();
        if (msgOut.videoEnd) {
            break;
        }
        GetResult(msgOut.inferOutputs);
    }
    LOG_PRINT("[INFO] *************** all get done ***************");
    exitFlag = true;
    return NULL;
}

/**
 * Entry point: initialises ACL on device 0, starts the pipeline threads
 * (GetInput -> ModelExecute -> PostProcess), waits for the pipeline to drain
 * and releases the ACL resources.
 *
 * @return 0 on success, 1 if ACL initialisation or thread creation fails.
 */
int main() {
    int32_t deviceId = 0;
    AclLiteResource aclResource(deviceId);
    bool ret = aclResource.Init();
    CHECK_RET(ret, LOG_PRINT("[ERROR] InitACLResource failed."); return 1);
    context = aclResource.GetContext();

    pthread_t preTid;
    pthread_t exeTid;
    pthread_t posTid;
    // Returning from main on failure ends the process, including any thread
    // already started; otherwise the pipeline would wait forever.
    if (pthread_create(&preTid, NULL, GetInput, static_cast<void*>(&deviceId)) != 0) {
        LOG_PRINT("[ERROR] create preprocess thread failed.");
        return 1;
    }
    // Not joined: if inference stops early, this thread can stay blocked on a
    // full msgDataQueue indefinitely.
    pthread_detach(preTid);
    if (pthread_create(&exeTid, NULL, ModelExecute, NULL) != 0) {
        LOG_PRINT("[ERROR] create model execute thread failed.");
        return 1;
    }
    if (pthread_create(&posTid, NULL, PostProcess, NULL) != 0) {
        LOG_PRINT("[ERROR] create postprocess thread failed.");
        return 1;
    }

    // PostProcess returns right after consuming the end marker. Joining it
    // replaces a 10 s exitFlag poll that could delay shutdown by up to 10 s.
    pthread_join(posTid, NULL);
    // ModelExecute pushes that end marker immediately before returning, so
    // this join completes promptly.
    pthread_join(exeTid, NULL);
    aclResource.Release();
    return 0;
}
